use crate::encoder::{errors::SpannedEncodingResult, mir::types::MirTypeEncoderInterface};
use std::collections::{BTreeMap, BTreeSet};
use vir_crate::{
    high as vir_high,
    typed::{self as vir_typed, operations::HighToTypedTypeDecl},
};

/// Bookkeeping used when converting `vir_high` types into `vir_typed` types.
#[derive(Default)]
pub(crate) struct HighToTypedTypeEncoderState {
    /// High types that `type_from_high_to_typed` has recorded as encoded,
    /// stored exactly as they were passed in.
    encoded_types: BTreeSet<vir_high::Type>,
    /// Reverse mapping from each produced typed type to the high type it was
    /// converted from. Both key and value are stored with lifetimes and const
    /// generics erased.
    encoded_types_inverse: BTreeMap<vir_typed::Type, vir_high::Type>,
}

/// Conversion of types between the `vir_high` and `vir_typed` IRs.
pub(crate) trait HighToTypedTypeEncoderInterface {
    /// Returns the `vir_typed` declaration for a typed type that was
    /// previously produced by [`Self::type_from_high_to_typed`].
    fn encode_type_def_typed(
        &mut self,
        ty: &vir_typed::Type,
    ) -> SpannedEncodingResult<vir_typed::TypeDecl>;
    /// Lowers a `vir_high` type into its `vir_typed` counterpart.
    fn type_from_high_to_typed(
        &mut self,
        ty: vir_high::Type,
    ) -> SpannedEncodingResult<vir_typed::Type>;
    /// Maps a `vir_typed` type back to the corresponding `vir_high` type.
    fn type_from_typed_to_high(
        &self,
        ty: &vir_typed::Type,
    ) -> SpannedEncodingResult<vir_high::Type>;
    /// Builds the identifier used for a tuple type with the given type
    /// arguments.
    fn generate_tuple_name(&self, arguments: &[vir_typed::Type]) -> SpannedEncodingResult<String>;
}

impl<'v, 'tcx: 'v> HighToTypedTypeEncoderInterface
    for super::super::super::super::Encoder<'v, 'tcx>
{
    /// Encodes the declaration of a typed type by looking up the high type it
    /// was converted from and lowering that high type's declaration.
    ///
    /// The inverse map is keyed by typed types with lifetimes and const
    /// generics erased, so `ty` must match such a key exactly.
    ///
    /// # Panics
    ///
    /// Panics if no high type was recorded for `ty`.
    fn encode_type_def_typed(
        &mut self,
        ty: &vir_typed::Type,
    ) -> SpannedEncodingResult<vir_typed::TypeDecl> {
        let high_type = self
            .typed_type_encoder_state
            .encoded_types_inverse
            .get(ty)
            .unwrap_or_else(|| {
                unreachable!("no high type recorded for typed type: {}\n{:?}", ty, ty)
            });
        let type_decl_high = self.encode_type_def_high(high_type)?;
        type_decl_high.high_to_typed_type_decl(self)
    }

    /// Lowers `ty` to `vir_typed`. The first successful conversion of a given
    /// high type also records the (erased) typed → high mapping used by
    /// `type_from_typed_to_high` and `encode_type_def_typed`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `default_high_to_typed_type_type`. On error the
    /// type is not left marked as encoded, so a later call retries and can
    /// still record the inverse mapping.
    fn type_from_high_to_typed(
        &mut self,
        ty: vir_high::Type,
    ) -> SpannedEncodingResult<vir_typed::Type> {
        // `BTreeSet::insert` returns `false` if the type was already present:
        // its inverse mapping is then already recorded and only the
        // conversion itself is needed.
        if !self
            .typed_type_encoder_state
            .encoded_types
            .insert(ty.clone())
        {
            return vir_typed::operations::default_high_to_typed_type_type(ty, self);
        }
        // The type is marked as encoded *before* converting it (as before), so a
        // re-entrant call for the same type during conversion takes the
        // already-encoded path above.
        match vir_typed::operations::default_high_to_typed_type_type(ty.clone(), self) {
            Ok(low_type) => {
                self.typed_type_encoder_state.encoded_types_inverse.insert(
                    low_type.erase_lifetimes().erase_const_generics(),
                    ty.erase_lifetimes().erase_const_generics(),
                );
                Ok(low_type)
            }
            Err(error) => {
                // Roll back the marker: leaving it set without an inverse
                // mapping would make later reverse lookups panic.
                self.typed_type_encoder_state.encoded_types.remove(&ty);
                Err(error)
            }
        }
    }

    /// Maps a typed type back to its high counterpart.
    ///
    /// Enum variants, `Bool`, `Int(Usize)` and pointers are handled
    /// structurally because the encoding may introduce them without a
    /// recorded mapping; every other type must be in the inverse map.
    ///
    /// # Panics
    ///
    /// Panics on lifetimes (unimplemented) and on types with no recorded
    /// mapping.
    fn type_from_typed_to_high(
        &self,
        ty: &vir_typed::Type,
    ) -> SpannedEncodingResult<vir_high::Type> {
        // TODO: Remove duplication with decode_type_high in
        // prusti-viper/src/encoder/mir/types/interface.rs
        if let Some((ty_without_variant, variant)) = ty.split_off_variant() {
            let without_variant = self
                .typed_type_encoder_state
                .encoded_types_inverse
                .get(&ty_without_variant)
                .unwrap_or_else(|| {
                    unreachable!(
                        "failed to decode type without variant: {}\n{:?}",
                        ty_without_variant, ty_without_variant
                    )
                })
                .clone();
            Ok(without_variant.variant(variant.index.clone().into()))
        } else if ty == &vir_typed::Type::Lifetime {
            unimplemented!("encode_type_high for lifetime {:?}", ty);
        } else if ty == &vir_typed::Type::Bool {
            // Bools may be generated by our encoding without having them in the
            // original program.
            Ok(vir_high::Type::Bool)
        } else if ty == &vir_typed::Type::Int(vir_typed::ty::Int::Usize) {
            // Usizes may be generated by our encoding without having them in
            // the original program.
            Ok(vir_high::Type::Int(vir_high::ty::Int::Usize))
        } else if let vir_typed::Type::Pointer(pointer) = ty {
            // We use pointer types for modelling addresses of references.
            let target_type = self.type_from_typed_to_high(&pointer.target_type)?;
            Ok(vir_high::Type::pointer(target_type))
        } else if let Some(ty) = self.typed_type_encoder_state.encoded_types_inverse.get(ty) {
            Ok(ty.clone())
        } else {
            unreachable!("failed to decode: {}\n{:?}", ty, ty)
        }
    }

    /// Returns `"Tuple$"` followed by the identifier encoding of `arguments`.
    fn generate_tuple_name(&self, arguments: &[vir_typed::Type]) -> SpannedEncodingResult<String> {
        let mut name = "Tuple$".to_string();
        vir_typed::operations::identifier::common::append_type_arguments(&mut name, arguments);
        Ok(name)
    }
}
